{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "57d23577",
   "metadata": {},
   "source": [
    "# [8、深入剖析PyTorch的state_dict、parameters、modules源码](https://www.bilibili.com/video/BV12g411K7oq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1e93ed0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.models as models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dd838b6",
   "metadata": {},
   "source": [
    "## save and load\n",
    "\n",
    "规范的做法是保存完整的 checkpoint：`当前epoch、模型权重、优化器状态、loss`，这样之后才能继续训练。\n",
    "\n",
    "如果只保存模型参数，那么只能够用来推理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29faec2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a VGG-16 via torchvision; no weights argument is passed to the constructor\n",
    "model = models.vgg16()\n",
    "# Save only the model's weights (its state_dict) to a file\n",
    "torch.save(model.state_dict(), 'model_weights.pth')\n",
    "# Save everything: the whole model object rather than just its state_dict\n",
    "torch.save(model, 'model.pth')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0653c2ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved weights. weights_only=True limits unpickling to tensors and\n",
    "# plain containers, so a tampered file cannot run arbitrary code on load.\n",
    "weights = torch.load('model_weights.pth', weights_only=True)\n",
    "# The file holds only a state_dict, so the architecture is rebuilt in code\n",
    "model = models.vgg16()\n",
    "# Copy the loaded tensors into the freshly built model\n",
    "model.load_state_dict(weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3250b2be",
   "metadata": {},
   "source": [
    "## module.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6777c383",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'linear1': Linear(in_features=2, out_features=3, bias=True),\n",
       " 'linear2': Linear(in_features=3, out_features=4, bias=True),\n",
       " 'batch_norm': BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch import Tensor\n",
    "\n",
    "\n",
    "class MyNet(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyNet, self).__init__()\n",
    "        self.linear1 = torch.nn.Linear(2,3)\n",
    "        self.linear2 = torch.nn.Linear(3,4)\n",
    "        self.batch_norm = torch.nn.BatchNorm2d(4)\n",
    "        \n",
    "\n",
    "    def forward(self, x: Tensor):\n",
    "        return x\n",
    "    \n",
    "net = MyNet()\n",
    "net._modules # "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b8ae2549",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n",
      "torch.bfloat16\n"
     ]
    }
   ],
   "source": [
    "# Dtype of linear2's weight before the cast (the saved output shows torch.float32)\n",
    "print(net._modules['linear2'].weight.dtype)\n",
    "# Use the to() method to change the parameter dtype. Its return value is not\n",
    "# assigned, yet the second print shows the new dtype, so net itself was converted.\n",
    "# NOTE(review): this makes the cell non-idempotent; re-running it without first\n",
    "# re-creating net would print torch.bfloat16 twice.\n",
    "net.to(torch.bfloat16)\n",
    "print(net._modules['linear2'].weight.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe940c32",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1c7f4eae",
   "metadata": {},
   "source": [
    "## container.py\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
